import torch
import torch.nn as nn
from dm03_decoder import *
import torch.nn.functional as F

class Generator(nn.Module):
    """Project feature vectors to vocabulary size and return log-probabilities.

    The module holds a single linear layer, ``out``, mapping ``d_model``
    input features to ``vocab_size`` output features. The forward pass
    applies that layer and then a log-softmax over the last dimension.
    """

    def __init__(self, d_model, vocab_size):
        """Create the output projection.

        Args:
            d_model: Number of input features of the linear layer.
            vocab_size: Number of output features of the linear layer.
        """
        super().__init__()
        # Keep the attribute name ``out``: it determines the parameter
        # names ("out.weight", "out.bias") stored in checkpoints.
        self.out = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Return the log-softmax of the projected input.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``vocab_size``, normalized with log-softmax
            along that last dimension.
        """
        logits = self.out(x)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs